{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jxv6goXm7oGF"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "llMNufAK7nfK"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Byow2J6LaPl"
      },
      "source": [
        "# tf.data: Build TensorFlow input pipelines"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kGXS3UWBBNoc"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/data\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/data.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/data.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/data.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Qo3HgDjbDcI"
      },
      "source": [
        "The `tf.data` API enables you to build complex input pipelines from simple,\n",
        "reusable pieces. For example, the pipeline for an image model might aggregate\n",
        "data from files in a distributed file system, apply random perturbations to each\n",
        "image, and merge randomly selected images into a batch for training. The\n",
        "pipeline for a text model might involve extracting symbols from raw text data,\n",
        "converting them to embedding identifiers with a lookup table, and batching\n",
        "together sequences of different lengths. The `tf.data` API makes it possible to\n",
        "handle large amounts of data, read from different data formats, and perform\n",
        "complex transformations.\n",
        "\n",
        "The `tf.data` API introduces a `tf.data.Dataset` abstraction that represents a\n",
        "sequence of elements, in which each element consists of one or more components.\n",
        "For example, in an image pipeline, an element might be a single training\n",
        "example, with a pair of tensor components representing the image and its label.\n",
        "\n",
        "There are two distinct ways to create a dataset:\n",
        "\n",
        "*   A data **source** constructs a `Dataset` from data stored in memory or in\n",
        "    one or more files.\n",
        "\n",
        "*   A data **transformation** constructs a dataset from one or more\n",
        "    `tf.data.Dataset` objects.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UJIEjEIBdf-h"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Y0JtWBNR9E5"
      },
      "outputs": [],
      "source": [
        "import pathlib\n",
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "\n",
        "np.set_printoptions(precision=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0l4a0ALxdaWF"
      },
      "source": [
        "## Basic mechanics\n",
        "<a id=\"basic-mechanics\"/>\n",
        "\n",
        "To create an input pipeline, you must start with a data *source*. For example,\n",
        "to construct a `Dataset` from data in memory, you can use\n",
        "`tf.data.Dataset.from_tensors()` or `tf.data.Dataset.from_tensor_slices()`.\n",
        "Alternatively, if your input data is stored in a file in the recommended\n",
        "TFRecord format, you can use `tf.data.TFRecordDataset()`.\n",
        "\n",
        "Once you have a `Dataset` object, you can *transform* it into a new `Dataset` by\n",
        "chaining method calls on the `tf.data.Dataset` object. For example, you can\n",
        "apply per-element transformations such as `Dataset.map()`, and multi-element\n",
        "transformations such as `Dataset.batch()`. See the documentation for\n",
        "`tf.data.Dataset` for a complete list of transformations.\n",
        "\n",
        "The `Dataset` object is a Python iterable. This makes it possible to consume its\n",
        "elements using a for loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0F-FDnjB6t6J"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pwJsRJ-FbDcJ"
      },
      "outputs": [],
      "source": [
        "for elem in dataset:\n",
        "  print(elem.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m0yy80MobDcM"
      },
      "source": [
        "Or by explicitly creating a Python iterator using `iter` and consuming its\n",
        "elements using `next`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "03w9oxFfbDcM"
      },
      "outputs": [],
      "source": [
        "it = iter(dataset)\n",
        "\n",
        "print(next(it).numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q4CgCL8qbDcO"
      },
      "source": [
        "Alternatively, dataset elements can be consumed using the `reduce`\n",
        "transformation, which reduces all elements to produce a single result. The\n",
        "following example illustrates how to use the `reduce` transformation to compute\n",
        "the sum of a dataset of integers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C2bHAeNxbDcO"
      },
      "outputs": [],
      "source": [
        "print(dataset.reduce(0, lambda state, value: state + value).numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B2Fzwt2nbDcR"
      },
      "source": [
        "<!-- TODO(jsimsa): Talk about `tf.function` support. -->\n",
        "\n",
        "<a id=\"dataset_structure\"></a>\n",
        "### Dataset structure\n",
        "\n",
        "A dataset produces a sequence of *elements*, where each element is\n",
        "the same (nested) structure of *components*. Individual components\n",
        "of the structure can be of any type representable by\n",
        "`tf.TypeSpec`, including `tf.Tensor`, `tf.sparse.SparseTensor`,\n",
        "`tf.RaggedTensor`, `tf.TensorArray`, or `tf.data.Dataset`.\n",
        "\n",
        "The Python constructs that can be used to express the (nested)\n",
        "structure of elements include `tuple`, `dict`, `NamedTuple`, and\n",
        "`OrderedDict`. In particular, `list` is not a valid construct for\n",
        "expressing the structure of dataset elements. This is because\n",
        "early tf.data users felt strongly about `list` inputs (e.g. passed\n",
        "to `tf.data.Dataset.from_tensors`) being automatically packed as\n",
        "tensors and `list` outputs (e.g. return values of user-defined\n",
        "functions) being coerced into a `tuple`. As a consequence, if you\n",
        "would like a `list` input to be treated as a structure, you need\n",
        "to convert it into `tuple` and if you would like a `list` output\n",
        "to be a single component, then you need to explicitly pack it\n",
        "using `tf.stack`.\n",
        "\n",
        "The `Dataset.element_spec` property allows you to inspect the type\n",
        "of each element component. The property returns a *nested structure*\n",
        "of `tf.TypeSpec` objects, matching the structure of the element,\n",
        "which may be a single component, a tuple of components, or a nested\n",
        "tuple of components. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mg0m1beIhXGn"
      },
      "outputs": [],
      "source": [
        "dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))\n",
        "\n",
        "dataset1.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cwyemaghhXaG"
      },
      "outputs": [],
      "source": [
        "dataset2 = tf.data.Dataset.from_tensor_slices(\n",
        "   (tf.random.uniform([4]),\n",
        "    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))\n",
        "\n",
        "dataset2.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1CL7aB0ahXn_"
      },
      "outputs": [],
      "source": [
        "dataset3 = tf.data.Dataset.zip((dataset1, dataset2))\n",
        "\n",
        "dataset3.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m5bz7R1xhX1f"
      },
      "outputs": [],
      "source": [
        "# Dataset containing a sparse tensor.\n",
        "dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))\n",
        "\n",
        "dataset4.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lVOPHur_hYQv"
      },
      "outputs": [],
      "source": [
        "# Use value_type to see the type of value represented by the element spec\n",
        "dataset4.element_spec.value_type"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r5xNsFFvhUnr"
      },
      "source": [
        "The `Dataset` transformations support datasets of any structure. When using the\n",
        "`Dataset.map()`, and `Dataset.filter()` transformations,\n",
        "which apply a function to each element, the element structure determines the\n",
        "arguments of the function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2myAr3Pxd-zF"
      },
      "outputs": [],
      "source": [
        "dataset1 = tf.data.Dataset.from_tensor_slices(\n",
        "    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))\n",
        "\n",
        "dataset1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "woPXMP14gUTg"
      },
      "outputs": [],
      "source": [
        "for z in dataset1:\n",
        "  print(z.numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "53PA4x6XgLar"
      },
      "outputs": [],
      "source": [
        "dataset2 = tf.data.Dataset.from_tensor_slices(\n",
        "   (tf.random.uniform([4]),\n",
        "    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))\n",
        "\n",
        "dataset2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ju4sNSebDcR"
      },
      "outputs": [],
      "source": [
        "dataset3 = tf.data.Dataset.zip((dataset1, dataset2))\n",
        "\n",
        "dataset3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BgxsfAS2g6gk"
      },
      "outputs": [],
      "source": [
        "for a, (b,c) in dataset3:\n",
        "  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M1s2K0g-bDcT"
      },
      "source": [
        "## Reading input data\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F3JG2f0h2683"
      },
      "source": [
        "### Consuming NumPy arrays\n",
        "\n",
        "See [Loading NumPy arrays](../tutorials/load_data/numpy.ipynb) for more examples.\n",
        "\n",
        "If all of your input data fits in memory, the simplest way to create a `Dataset`\n",
        "from them is to convert them to `tf.Tensor` objects and use\n",
        "`Dataset.from_tensor_slices()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NmaE6PjjhQ47"
      },
      "outputs": [],
      "source": [
        "train, test = tf.keras.datasets.fashion_mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J6cNiuDBbDcU"
      },
      "outputs": [],
      "source": [
        "images, labels = train\n",
        "images = images/255\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices((images, labels))\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkwrDHN5bDcW"
      },
      "source": [
        "Note: The above code snippet will embed the `features` and `labels` arrays\n",
        "in your TensorFlow graph as `tf.constant()` operations. This works well for a\n",
        "small dataset, but wastes memory---because the contents of the array will be\n",
        "copied multiple times---and can run into the 2GB limit for the `tf.GraphDef`\n",
        "protocol buffer."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pO4ua2gEmIhR"
      },
      "source": [
        "### Consuming Python generators\n",
        "\n",
        "Another common data source that can easily be ingested as a `tf.data.Dataset` is the python generator.\n",
        "\n",
        "Caution: While this is a convienient approach it has limited portability and scalibility. It must run in the same python process that created the generator, and is still subject to the Python [GIL](https://en.wikipedia.org/wiki/Global_interpreter_lock)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9njpME-jmDza"
      },
      "outputs": [],
      "source": [
        "def count(stop):\n",
        "  i = 0\n",
        "  while i<stop:\n",
        "    yield i\n",
        "    i += 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xwqLrjnTpD8Y"
      },
      "outputs": [],
      "source": [
        "for n in count(5):\n",
        "  print(n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D_BB_PhxnVVx"
      },
      "source": [
        "The `Dataset.from_generator` constructor converts the python generator to a fully functional `tf.data.Dataset`.\n",
        "\n",
        "The constructor takes a callable as input, not an iterator. This allows it to restart the generator when it reaches the end. It takes an optional `args` argument, which is passed as the callable's arguments.\n",
        "\n",
        "The `output_types` argument is required because `tf.data` builds a `tf.Graph` internally, and graph edges require a `tf.dtype`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GFga_OTwm0Je"
      },
      "outputs": [],
      "source": [
        "ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fel1SUuBnDUE"
      },
      "outputs": [],
      "source": [
        "for count_batch in ds_counter.repeat().batch(10).take(10):\n",
        "  print(count_batch.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wxy9hDMTq1zD"
      },
      "source": [
        "The `output_shapes` argument is not *required* but is highly recomended as many tensorflow operations do not support tensors with unknown rank. If the length of a particular axis is unknown or variable, set it as `None` in the `output_shapes`.\n",
        "\n",
        "It's also important to note that the `output_shapes` and `output_types` follow the same nesting rules as other dataset methods.\n",
        "\n",
        "Here is an example generator that demonstrates both aspects, it returns tuples of arrays, where the second array is a vector with unknown length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "allFX1g8rGKe"
      },
      "outputs": [],
      "source": [
        "def gen_series():\n",
        "  i = 0\n",
        "  while True:\n",
        "    size = np.random.randint(0, 10)\n",
        "    yield i, np.random.normal(size=(size,))\n",
        "    i += 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Ku26Yb9rcJX"
      },
      "outputs": [],
      "source": [
        "for i, series in gen_series():\n",
        "  print(i, \":\", str(series))\n",
        "  if i > 5:\n",
        "    break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LmkynGilx0qf"
      },
      "source": [
        "The first output is an `int32` the second is a `float32`.\n",
        "\n",
        "The first item is a scalar, shape `()`, and the second is a vector of unknown length, shape `(None,)` "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zDTfhEzhsliM"
      },
      "outputs": [],
      "source": [
        "ds_series = tf.data.Dataset.from_generator(\n",
        "    gen_series, \n",
        "    output_types=(tf.int32, tf.float32), \n",
        "    output_shapes=((), (None,)))\n",
        "\n",
        "ds_series"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WWxvSyQiyN0o"
      },
      "source": [
        "Now it can be used like a regular `tf.data.Dataset`. Note that when batching a dataset with a variable shape, you need to use `Dataset.padded_batch`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A7jEpj3As1lO"
      },
      "outputs": [],
      "source": [
        "ds_series_batch = ds_series.shuffle(20).padded_batch(10)\n",
        "\n",
        "ids, sequence_batch = next(iter(ds_series_batch))\n",
        "print(ids.numpy())\n",
        "print()\n",
        "print(sequence_batch.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_hcqOccJ1CxG"
      },
      "source": [
        "For a more realistic example, try wrapping `preprocessing.image.ImageDataGenerator` as a `tf.data.Dataset`.\n",
        "\n",
        "First download the data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g-_JCFRQ1CXM"
      },
      "outputs": [],
      "source": [
        "flowers = tf.keras.utils.get_file(\n",
        "    'flower_photos',\n",
        "    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
        "    untar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UIjPhvQ87jUT"
      },
      "source": [
        "Create the `image.ImageDataGenerator`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPCZeBQE5DfH"
      },
      "outputs": [],
      "source": [
        "img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "my4PxqfH26p6"
      },
      "outputs": [],
      "source": [
        "images, labels = next(img_gen.flow_from_directory(flowers))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hd96nH1w3eKH"
      },
      "outputs": [],
      "source": [
        "print(images.dtype, images.shape)\n",
        "print(labels.dtype, labels.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KvRwvt5E2rTH"
      },
      "outputs": [],
      "source": [
        "ds = tf.data.Dataset.from_generator(\n",
        "    lambda: img_gen.flow_from_directory(flowers), \n",
        "    output_types=(tf.float32, tf.float32), \n",
        "    output_shapes=([32,256,256,3], [32,5])\n",
        ")\n",
        "\n",
        "ds.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LcaULBCXj_2_"
      },
      "outputs": [],
      "source": [
        "for images, label in ds.take(1):\n",
        "  print('images.shape: ', images.shape)\n",
        "  print('labels.shape: ', labels.shape)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ma4XoYzih2f4"
      },
      "source": [
        "### Consuming TFRecord data\n",
        "\n",
        "See [Loading TFRecords](../tutorials/load_data/tfrecord.ipynb) for an end-to-end example.\n",
        "\n",
        "The `tf.data` API supports a variety of file formats so that you can process\n",
        "large datasets that do not fit in memory. For example, the TFRecord file format\n",
        "is a simple record-oriented binary format that many TensorFlow applications use\n",
        "for training data. The `tf.data.TFRecordDataset` class enables you to\n",
        "stream over the contents of one or more TFRecord files as part of an input\n",
        "pipeline."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LiatWUloRJc4"
      },
      "source": [
        "Here is an example using the test file from the French Street Name Signs (FSNS)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jZo_4fzdbDcW"
      },
      "outputs": [],
      "source": [
        "# Creates a dataset that reads all of the examples from two files.\n",
        "fsns_test_file = tf.keras.utils.get_file(\"fsns.tfrec\", \"https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "seD5bOH3RhBP"
      },
      "source": [
        "The `filenames` argument to the `TFRecordDataset` initializer can either be a\n",
        "string, a list of strings, or a `tf.Tensor` of strings. Therefore if you have\n",
        "two sets of files for training and validation purposes, you can create a factory\n",
        "method that produces the dataset, taking filenames as an input argument:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e2WV5d7DRUA-"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "62NC3vz9U8ww"
      },
      "source": [
        "Many TensorFlow projects use serialized `tf.train.Example` records in their TFRecord files. These need to be decoded before they can be inspected:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3tk29nlMl5P3"
      },
      "outputs": [],
      "source": [
        "raw_example = next(iter(dataset))\n",
        "parsed = tf.train.Example.FromString(raw_example.numpy())\n",
        "\n",
        "parsed.features.feature['image/text']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJAUib10bDcb"
      },
      "source": [
        "### Consuming text data\n",
        "\n",
        "See [Loading Text](../tutorials/load_data/text.ipynb) for an end to end example.\n",
        "\n",
        "Many datasets are distributed as one or more text files. The\n",
        "`tf.data.TextLineDataset` provides an easy way to extract lines from one or more\n",
        "text files. Given one or more filenames, a `TextLineDataset` will produce one\n",
        "string-valued element per line of those files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hQMoFu2TbDcc"
      },
      "outputs": [],
      "source": [
        "directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'\n",
        "file_names = ['cowper.txt', 'derby.txt', 'butler.txt']\n",
        "\n",
        "file_paths = [\n",
        "    tf.keras.utils.get_file(file_name, directory_url + file_name)\n",
        "    for file_name in file_names\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "il4cOjiVwj95"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.TextLineDataset(file_paths)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MevIbDiwy4MC"
      },
      "source": [
        "Here are the first few lines of the first file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vpEHKyvHxu8A"
      },
      "outputs": [],
      "source": [
        "for line in dataset.take(5):\n",
        "  print(line.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lJyVw8ro7fey"
      },
      "source": [
        "To alternate lines between files use `Dataset.interleave`. This makes it easier to shuffle files together. Here are the first, second and third lines from each translation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1UCveWOt7fDE"
      },
      "outputs": [],
      "source": [
        "files_ds = tf.data.Dataset.from_tensor_slices(file_paths)\n",
        "lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)\n",
        "\n",
        "for i, line in enumerate(lines_ds.take(9)):\n",
        "  if i % 3 == 0:\n",
        "    print()\n",
        "  print(line.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2F_pOIDubDce"
      },
      "source": [
        "By default, a `TextLineDataset` yields *every* line of each file, which may\n",
        "not be desirable, for example, if the file starts with a header line, or contains comments. These lines can be removed using the `Dataset.skip()` or\n",
        "`Dataset.filter()` transformations. Here, you skip the first line, then filter to\n",
        "find only survivors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X6b20Gua2jPO"
      },
      "outputs": [],
      "source": [
        "titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
        "titanic_lines = tf.data.TextLineDataset(titanic_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5M1pauNT68B2"
      },
      "outputs": [],
      "source": [
        "for line in titanic_lines.take(10):\n",
        "  print(line.numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dEIP95cibDcf"
      },
      "outputs": [],
      "source": [
        "def survived(line):\n",
        "  return tf.not_equal(tf.strings.substr(line, 0, 1), \"0\")\n",
        "\n",
        "survivors = titanic_lines.skip(1).filter(survived)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "odQ4618h1XqD"
      },
      "outputs": [],
      "source": [
        "for line in survivors.take(10):\n",
        "  print(line.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x5z5B11UjDTd"
      },
      "source": [
        "### Consuming CSV data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ChDHNi3qbDch"
      },
      "source": [
        "See [Loading CSV Files](../tutorials/load_data/csv.ipynb), and [Loading Pandas DataFrames](../tutorials/load_data/pandas.ipynb) for more examples. \n",
        "\n",
        "The CSV file format is a popular format for storing tabular data in plain text.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kj28j5u49Bjm"
      },
      "outputs": [],
      "source": [
        "titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ghvtmW40LM0B"
      },
      "outputs": [],
      "source": [
        "df = pd.read_csv(titanic_file)\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J9uBqt5oGsR-"
      },
      "source": [
        "If your data fits in memory the same `Dataset.from_tensor_slices` method works on dictionaries, allowing this data to be easily imported:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JmAMCiPJA0qO"
      },
      "outputs": [],
      "source": [
        "titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))\n",
        "\n",
        "for feature_batch in titanic_slices.take(1):\n",
        "  for key, value in feature_batch.items():\n",
        "    print(\"  {!r:20s}: {}\".format(key, value))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "47yippqaHFk6"
      },
      "source": [
        "A more scalable approach is to load from disk as necessary.\n",
        "\n",
        "The `tf.data` module provides methods to extract records from one or more CSV files that comply with [RFC 4180](https://tools.ietf.org/html/rfc4180).\n",
        "\n",
        "The `experimental.make_csv_dataset` function is the high-level interface for reading sets of CSV files. It supports column type inference and many other features, like batching and shuffling, to make usage simple."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHUDrM_s_brq"
      },
      "outputs": [],
      "source": [
        "titanic_batches = tf.data.experimental.make_csv_dataset(\n",
        "    titanic_file, batch_size=4,\n",
        "    label_name=\"survived\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TsZfhz79_Wlg"
      },
      "outputs": [],
      "source": [
        "for feature_batch, label_batch in titanic_batches.take(1):\n",
        "  print(\"'survived': {}\".format(label_batch))\n",
        "  print(\"features:\")\n",
        "  for key, value in feature_batch.items():\n",
        "    print(\"  {!r:20s}: {}\".format(key, value))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k_5N7CdNGYAa"
      },
      "source": [
        "You can use the `select_columns` argument if you only need a subset of columns."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H9KNHyDwF2Sc"
      },
      "outputs": [],
      "source": [
        "titanic_batches = tf.data.experimental.make_csv_dataset(\n",
        "    titanic_file, batch_size=4,\n",
        "    label_name=\"survived\", select_columns=['class', 'fare', 'survived'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7C2uosFnGIT8"
      },
      "outputs": [],
      "source": [
        "for feature_batch, label_batch in titanic_batches.take(1):\n",
        "  print(\"'survived': {}\".format(label_batch))\n",
        "  for key, value in feature_batch.items():\n",
        "    print(\"  {!r:20s}: {}\".format(key, value))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TSVgJJ1HJD6M"
      },
      "source": [
        "There is also a lower-level `experimental.CsvDataset` class, which provides finer-grained control. It does not support column type inference. Instead, you must specify the type of each column."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wP1Y_NXA8bYl"
      },
      "outputs": [],
      "source": [
        "titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]\n",
        "dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)\n",
        "\n",
        "for line in dataset.take(10):\n",
        "  print([item.numpy() for item in line])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oZSuLVsTbDcj"
      },
      "source": [
        "If some columns are empty, this low-level interface allows you to provide default values instead of column types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qry-g90FMo2I"
      },
      "outputs": [],
      "source": [
        "%%writefile missing.csv\n",
        "1,2,3,4\n",
        ",2,3,4\n",
        "1,,3,4\n",
        "1,2,,4\n",
        "1,2,3,\n",
        ",,,"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d5_hbiE9bDck"
      },
      "outputs": [],
      "source": [
        "# Creates a dataset that reads all of the records from a CSV file with four\n",
        "# integer columns, any of which may have missing values.\n",
        "\n",
        "record_defaults = [999,999,999,999]\n",
        "dataset = tf.data.experimental.CsvDataset(\"missing.csv\", record_defaults)\n",
        "dataset = dataset.map(lambda *items: tf.stack(items))\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "__jc7iD9M9FC"
      },
      "outputs": [],
      "source": [
        "for line in dataset:\n",
        "  print(line.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z_4g0cIvbDcl"
      },
      "source": [
        "By default, a `CsvDataset` yields *every* column of *every* line of the file,\n",
        "which may not be desirable, for example if the file starts with a header line\n",
        "that should be ignored, or if some columns are not required in the input.\n",
        "These lines and fields can be removed with the `header` and `select_cols`\n",
        "arguments respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p2IF_K0obDcm"
      },
      "outputs": [],
      "source": [
        "# Creates a dataset that reads all of the records from a CSV file, extracting\n",
        "# integer data from columns 2 and 4 (zero-based indices 1 and 3).\n",
        "record_defaults = [999, 999] # Only provide defaults for the selected columns\n",
        "dataset = tf.data.experimental.CsvDataset(\"missing.csv\", record_defaults, select_cols=[1, 3])\n",
        "dataset = dataset.map(lambda *items: tf.stack(items))\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-5aLprDeRNb0"
      },
      "outputs": [],
      "source": [
        "for line in dataset:\n",
        "  print(line.numpy())"
      ]
    },
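    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `missing.csv` examples above do not exercise `header`, because that file has no header row. As a minimal sketch, the next cell writes a small file with a header line and skips that line with `header=True`. The file name `with_header.csv` and its contents are hypothetical, chosen only for this illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical file, written only for this illustration.\n",
        "with open('with_header.csv', 'w') as f:\n",
        "  f.write('a,b,c\\n1,2,3\\n4,,6\\n')\n",
        "\n",
        "# header=True skips the first line; select_cols keeps columns 0 and 2,\n",
        "# so only two defaults are needed.\n",
        "header_ds = tf.data.experimental.CsvDataset(\n",
        "    'with_header.csv', record_defaults=[999, 999],\n",
        "    header=True, select_cols=[0, 2])\n",
        "\n",
        "rows = [[int(item.numpy()) for item in row] for row in header_ds]\n",
        "print(rows)  # [[1, 3], [4, 6]]"
      ]
    },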
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-CJfhb03koVN"
      },
      "source": [
        "### Consuming sets of files"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yAO7SZDSk57_"
      },
      "source": [
        "Many datasets are distributed as a set of files, where each file is an example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1dZwN3CS-jV2"
      },
      "outputs": [],
      "source": [
        "flowers_root = tf.keras.utils.get_file(\n",
        "    'flower_photos',\n",
        "    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
        "    untar=True)\n",
        "flowers_root = pathlib.Path(flowers_root)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4099UU8n-jHP"
      },
      "source": [
        "Note: these images are licensed CC-BY; see LICENSE.txt for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FCyTYpmDs_jE"
      },
      "source": [
        "The root directory contains a directory for each class:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_2iCXsHu6jJH"
      },
      "outputs": [],
      "source": [
        "for item in flowers_root.glob(\"*\"):\n",
        "  print(item.name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ylj9fgkamgWZ"
      },
      "source": [
        "The files in each class directory are examples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lAkQp5uxoINu"
      },
      "outputs": [],
      "source": [
        "list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))\n",
        "\n",
        "for f in list_ds.take(5):\n",
        "  print(f.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "91CPfUUJ_8SZ"
      },
      "source": [
        "Read the data using the `tf.io.read_file` function and extract the label from the path, returning `(image, label)` pairs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-xhBRgvNqRRe"
      },
      "outputs": [],
      "source": [
        "def process_path(file_path):\n",
        "  label = tf.strings.split(file_path, os.sep)[-2]\n",
        "  return tf.io.read_file(file_path), label\n",
        "\n",
        "labeled_ds = list_ds.map(process_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kxrl0lGdnpRz"
      },
      "outputs": [],
      "source": [
        "for image_raw, label_text in labeled_ds.take(1):\n",
        "  print(repr(image_raw.numpy()[:100]))\n",
        "  print()\n",
        "  print(label_text.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yEh46Ee0oSH5"
      },
      "source": [
        "<!--\n",
        "TODO(mrry): Add this section.\n",
        "\n",
        "### Handling text data with unusual sizes\n",
        "-->\n",
        "\n",
        "## Batching dataset elements\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gR-2xY-8oSH4"
      },
      "source": [
        "### Simple batching\n",
        "\n",
        "The simplest form of batching stacks `n` consecutive elements of a dataset into\n",
        "a single element. The `Dataset.batch()` transformation does exactly this, with\n",
        "the same constraints as the `tf.stack()` operator, applied to each component\n",
        "of the elements: i.e. for each component *i*, all elements must have a tensor\n",
        "of the exact same shape."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xB7KeceLoSH0"
      },
      "outputs": [],
      "source": [
        "inc_dataset = tf.data.Dataset.range(100)\n",
        "dec_dataset = tf.data.Dataset.range(0, -100, -1)\n",
        "dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))\n",
        "batched_dataset = dataset.batch(4)\n",
        "\n",
        "for batch in batched_dataset.take(4):\n",
        "  print([arr.numpy() for arr in batch])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LlV1tpFdoSH0"
      },
      "source": [
        "While `tf.data` tries to propagate shape information, the default settings of `Dataset.batch` result in an unknown batch size because the last batch may not be full. Note the `None`s in the shape:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yN7hn7OBoSHx"
      },
      "outputs": [],
      "source": [
        "batched_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "It1fPA3NoSHw"
      },
      "source": [
        "Use the `drop_remainder` argument to ignore that last batch, and get full shape propagation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BycWC7WCoSHt"
      },
      "outputs": [],
      "source": [
        "batched_dataset = dataset.batch(7, drop_remainder=True)\n",
        "batched_dataset"
      ]
    },
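    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `drop_remainder` discards, the following sketch counts the batches produced with and without it. The 100-element `Dataset.range` and the variable names are assumptions made for this illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "ds = tf.data.Dataset.range(100)\n",
        "\n",
        "# 100 = 14 * 7 + 2, so the final partial batch holds 2 elements.\n",
        "with_remainder = [int(b.shape[0]) for b in ds.batch(7)]\n",
        "without_remainder = [int(b.shape[0]) for b in ds.batch(7, drop_remainder=True)]\n",
        "\n",
        "print(len(with_remainder), with_remainder[-1])        # 15 2\n",
        "print(len(without_remainder), without_remainder[-1])  # 14 7"
      ]
    },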
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mj9nRxFZoSHs"
      },
      "source": [
        "### Batching tensors with padding\n",
        "\n",
        "The above recipe works for tensors that all have the same size. However, many\n",
        "models (e.g. sequence models) work with input data that can have varying size\n",
        "(e.g. sequences of different lengths). To handle this case, the\n",
        "`Dataset.padded_batch` transformation enables you to batch tensors of\n",
        "different shape by specifying one or more dimensions in which they may be\n",
        "padded."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kycwO0JooSHn"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.range(100)\n",
        "dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))\n",
        "dataset = dataset.padded_batch(4, padded_shapes=(None,))\n",
        "\n",
        "for batch in dataset.take(2):\n",
        "  print(batch.numpy())\n",
        "  print()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wl3yhth1oSHm"
      },
      "source": [
        "The `Dataset.padded_batch` transformation allows you to set different padding\n",
        "for each dimension of each component, and it may be variable-length (signified\n",
        "by `None` in the example above) or constant-length. It is also possible to\n",
        "override the padding value, which defaults to 0.\n",
        "\n",
        "<!--\n",
        "TODO(mrry): Add this section.\n",
        "\n",
        "### Dense ragged -> tf.SparseTensor\n",
        "-->\n"
      ]
    },
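    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the two options just described, the following cell pads every element to a constant length and overrides the padding value. The length `5` and the padding value `-1` are arbitrary choices made for this illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "ds = tf.data.Dataset.range(1, 4)  # Elements: 1, 2, 3\n",
        "ds = ds.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))\n",
        "\n",
        "# Pad each vector to a constant length of 5, using -1 instead of 0.\n",
        "# The padding value must have the same dtype as the component (int64 here).\n",
        "padded = ds.padded_batch(\n",
        "    3, padded_shapes=[5],\n",
        "    padding_values=tf.constant(-1, dtype=tf.int64))\n",
        "\n",
        "padded_batch = next(iter(padded)).numpy()\n",
        "print(padded_batch)"
      ]
    },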
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G8zbAxMwoSHl"
      },
      "source": [
        "## Training workflows\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UnlhzF_AoSHk"
      },
      "source": [
        "### Processing multiple epochs\n",
        "\n",
        "The `tf.data` API offers two main ways to process multiple epochs of the same\n",
        "data.\n",
        "\n",
        "The simplest way to iterate over a dataset in multiple epochs is to use the\n",
        "`Dataset.repeat()` transformation. First, create a dataset of titanic data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0tODHZzRoSHg"
      },
      "outputs": [],
      "source": [
        "titanic_file = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
        "titanic_lines = tf.data.TextLineDataset(titanic_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LMO6mlXxoSHc"
      },
      "outputs": [],
      "source": [
        "def plot_batch_sizes(ds):\n",
        "  batch_sizes = [batch.shape[0] for batch in ds]\n",
        "  plt.bar(range(len(batch_sizes)), batch_sizes)\n",
        "  plt.xlabel('Batch number')\n",
        "  plt.ylabel('Batch size')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WfVzmqL7oSHa"
      },
      "source": [
        "Applying the `Dataset.repeat()` transformation with no arguments will repeat\n",
        "the input indefinitely.\n",
        "\n",
        "The `Dataset.repeat` transformation concatenates the repeated copies of its\n",
        "input without signaling the end of one epoch and the beginning of the next\n",
        "epoch. Because of this, a `Dataset.batch` applied after `Dataset.repeat` will yield batches that straddle epoch boundaries:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nZ0G1cztoSHX"
      },
      "outputs": [],
      "source": [
        "titanic_batches = titanic_lines.repeat(3).batch(128)\n",
        "plot_batch_sizes(titanic_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "moH-4gBEoSHW"
      },
      "source": [
        "If you need clear epoch separation, put `Dataset.batch` before the repeat:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wmbmdK1qoSHS"
      },
      "outputs": [],
      "source": [
        "titanic_batches = titanic_lines.batch(128).repeat(3)\n",
        "\n",
        "plot_batch_sizes(titanic_batches)"
      ]
    },
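    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same effect can be checked numerically on a tiny dataset. This is a sketch that assumes a 10-element `Dataset.range` and a batch size of 4 rather than the Titanic file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "ds = tf.data.Dataset.range(10)\n",
        "\n",
        "# repeat -> batch: 20 elements are batched as one stream, so the third\n",
        "# batch straddles the epoch boundary.\n",
        "repeat_then_batch = [int(b.shape[0]) for b in ds.repeat(2).batch(4)]\n",
        "\n",
        "# batch -> repeat: each epoch ends with its own short batch.\n",
        "batch_then_repeat = [int(b.shape[0]) for b in ds.batch(4).repeat(2)]\n",
        "\n",
        "print(repeat_then_batch)  # [4, 4, 4, 4, 4]\n",
        "print(batch_then_repeat)  # [4, 4, 2, 4, 4, 2]"
      ]
    },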
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DlEM5f9loSHR"
      },
      "source": [
        "If you would like to perform a custom computation (e.g. to collect statistics) at the end of each epoch, then it's simplest to restart the dataset iteration on each epoch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YyekyeY7oSHO"
      },
      "outputs": [],
      "source": [
        "epochs = 3\n",
        "dataset = titanic_lines.batch(128)\n",
        "\n",
        "for epoch in range(epochs):\n",
        "  for batch in dataset:\n",
        "    print(batch.shape)\n",
        "  print(\"End of epoch: \", epoch)"
      ]
    },
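    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As one possible shape for such a per-epoch computation, the following sketch accumulates a running sum inside the inner loop and records it when the epoch ends. The small `Dataset.range` and the choice of statistic are assumptions made for this illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "ds = tf.data.Dataset.range(10).batch(4)\n",
        "\n",
        "epoch_sums = []\n",
        "for epoch in range(2):\n",
        "  total = 0\n",
        "  for batch in ds:\n",
        "    total += int(tf.reduce_sum(batch).numpy())\n",
        "  # The custom end-of-epoch computation goes here.\n",
        "  epoch_sums.append(total)\n",
        "\n",
        "print(epoch_sums)  # [45, 45]"
      ]
    },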
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Bci79WCoSHN"
      },
      "source": [
        "### Randomly shuffling input data\n",
        "\n",
        "The `Dataset.shuffle()` transformation maintains a fixed-size\n",
        "buffer and chooses the next element uniformly at random from that buffer.\n",
        "\n",
        "Note: While a large `buffer_size` shuffles more thoroughly, it can take a lot of memory and significant time to fill. Consider using `Dataset.interleave` across files if this becomes a problem."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6YvXr-qeoSHL"
      },
      "source": [
        "Add an index to the dataset so you can see the effect:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Io4iJH1toSHI"
      },
      "outputs": [],
      "source": [
        "lines = tf.data.TextLineDataset(titanic_file)\n",
        "counter = tf.data.experimental.Counter()\n",
        "\n",
        "dataset = tf.data.Dataset.zip((counter, lines))\n",
        "dataset = dataset.shuffle(buffer_size=100)\n",
        "dataset = dataset.batch(20)\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T6tNYRcsoSHH"
      },
      "source": [
        "Since the `buffer_size` is 100, and the batch size is 20, the first batch contains no elements with an index over 120."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ayM3FFFAoSHC"
      },
      "outputs": [],
      "source": [
        "n,line_batch = next(iter(dataset))\n",
        "print(n.numpy())"
      ]
    },
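    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The bound follows from how the buffer is filled: each draw replaces one buffered element with the next input element, so the first batch can only contain indices below `buffer_size + batch_size`. A sketch that checks this on a plain `Dataset.range` (an assumption made for this illustration, in place of the Titanic lines):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "buffer_size, batch_size = 100, 20\n",
        "ds = tf.data.Dataset.range(1000).shuffle(buffer_size).batch(batch_size)\n",
        "\n",
        "first_batch = next(iter(ds)).numpy()\n",
        "\n",
        "# The order is random, but every index stays below buffer_size + batch_size.\n",
        "print(first_batch.max() < buffer_size + batch_size)"
      ]
    },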
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PLrfIjTHoSHB"
      },
      "source": [
        "As with `Dataset.batch`, the order relative to `Dataset.repeat` matters.\n",
        "\n",
        "`Dataset.shuffle` doesn't signal the end of an epoch until the shuffle buffer is empty. So a shuffle placed before a repeat will show every element of one epoch before moving to the next:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YX3pe7zZoSG6"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.zip((counter, lines))\n",
        "shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)\n",
        "\n",
        "print(\"Here are the item IDs near the epoch boundary:\\n\")\n",
        "for n, line_batch in shuffled.skip(60).take(5):\n",
        "  print(n.numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H9hlE-lGoSGz"
      },
      "outputs": [],
      "source": [
        "shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]\n",
        "plt.plot(shuffle_repeat, label=\"shuffle().repeat()\")\n",
        "plt.ylabel(\"Mean item ID\")\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UucIgCxWoSGx"
      },
      "source": [
        "But a repeat before a shuffle mixes the epoch boundaries together:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bhxb5YGZoSGm"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.zip((counter, lines))\n",
        "shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)\n",
        "\n",
        "print(\"Here are the item IDs near the epoch boundary:\\n\")\n",
        "for n, line_batch in shuffled.skip(55).take(15):\n",
        "  print(n.numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VAM4cbpZoSGL"
      },
      "outputs": [],
      "source": [
        "repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]\n",
        "\n",
        "plt.plot(shuffle_repeat, label=\"shuffle().repeat()\")\n",
        "plt.plot(repeat_shuffle, label=\"repeat().shuffle()\")\n",
        "plt.ylabel(\"Mean item ID\")\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ianlfbrxbDco"
      },
      "source": [
        "## Preprocessing data\n",
        "\n",
        "The `Dataset.map(f)` transformation produces a new dataset by applying a given\n",
        "function `f` to each element of the input dataset. It is based on the\n",
        "[`map()`](https://en.wikipedia.org/wiki/Map_\\(higher-order_function\\)) function\n",
        "that is commonly applied to lists (and other structures) in functional\n",
        "programming languages. The function `f` takes the `tf.Tensor` objects that\n",
        "represent a single element in the input, and returns the `tf.Tensor` objects\n",
        "that will represent a single element in the new dataset. Its implementation uses\n",
        "standard TensorFlow operations to transform one element into another.\n",
        "\n",
        "This section covers common examples of how to use `Dataset.map()`.\n"
      ]
    },
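    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before the larger examples, here is a minimal sketch of the idea on a small `Dataset.range`. The dataset and the lambda functions are illustrative assumptions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "ds = tf.data.Dataset.range(5)\n",
        "\n",
        "# f receives one element (a scalar tf.Tensor) and returns the new element.\n",
        "squared = ds.map(lambda x: x * x)\n",
        "squared_values = [int(x.numpy()) for x in squared]\n",
        "print(squared_values)  # [0, 1, 4, 9, 16]\n",
        "\n",
        "# With tuple elements, f receives one argument per component.\n",
        "sums = tf.data.Dataset.zip((ds, squared)).map(lambda a, b: a + b)\n",
        "sum_values = [int(x.numpy()) for x in sums]\n",
        "print(sum_values)  # [0, 2, 6, 12, 20]"
      ]
    },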
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UXw1IZVdbDcq"
      },
      "source": [
        "### Decoding image data and resizing it\n",
        "\n",
        "<!-- TODO(markdaoust): link to image augmentation when it exists -->\n",
        "When training a neural network on real-world image data, it is often necessary\n",
        "to convert images of different sizes to a common size, so that they may be\n",
        "batched into a fixed size.\n",
        "\n",
        "Rebuild the flower filenames dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rMGlj8V-u-NH"
      },
      "outputs": [],
      "source": [
        "list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GyhZLB8N5jBm"
      },
      "source": [
        "Write a function that manipulates the dataset elements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fZObC0debDcr"
      },
      "outputs": [],
      "source": [
        "# Reads an image from a file, decodes it into a dense tensor, and resizes it\n",
        "# to a fixed shape.\n",
        "def parse_image(filename):\n",
        "  parts = tf.strings.split(filename, os.sep)\n",
        "  label = parts[-2]\n",
        "\n",
        "  image = tf.io.read_file(filename)\n",
        "  image = tf.image.decode_jpeg(image)\n",
        "  image = tf.image.convert_image_dtype(image, tf.float32)\n",
        "  image = tf.image.resize(image, [128, 128])\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e0dVJlCA5qHA"
      },
      "source": [
        "Test that it works."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y8xuN_HBzGup"
      },
      "outputs": [],
      "source": [
        "file_path = next(iter(list_ds))\n",
        "image, label = parse_image(file_path)\n",
        "\n",
        "def show(image, label):\n",
        "  plt.figure()\n",
        "  plt.imshow(image)\n",
        "  plt.title(label.numpy().decode('utf-8'))\n",
        "  plt.axis('off')\n",
        "\n",
        "show(image, label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3P8N-S55vDu"
      },
      "source": [
        "Map it over the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SzO8LI_H5Sk_"
      },
      "outputs": [],
      "source": [
        "images_ds = list_ds.map(parse_image)\n",
        "\n",
        "for image, label in images_ds.take(2):\n",
        "  show(image, label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Ff7IqB9bDcs"
      },
      "source": [
        "### Applying arbitrary Python logic\n",
        "\n",
        "For performance reasons, use TensorFlow operations for\n",
        "preprocessing your data whenever possible. However, it is sometimes useful to\n",
        "call external Python libraries when parsing your input data. You can use the `tf.py_function()` operation in a `Dataset.map()` transformation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R2u7CeA67DU8"
      },
      "source": [
        "For example, if you want to apply a random rotation, the `tf.image` module only has `tf.image.rot90`, which is not very useful for image augmentation.\n",
        "\n",
        "Note: `tensorflow_addons` has a TensorFlow-compatible `rotate` in `tensorflow_addons.image.rotate`.\n",
        "\n",
        "To demonstrate `tf.py_function`, try using the `scipy.ndimage.rotate` function instead:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tBUmbERt7Czz"
      },
      "outputs": [],
      "source": [
        "import scipy.ndimage as ndimage\n",
        "\n",
        "def random_rotate_image(image):\n",
        "  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)\n",
        "  return image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_wEyL7bS9S6t"
      },
      "outputs": [],
      "source": [
        "image, label = next(iter(images_ds))\n",
        "image = random_rotate_image(image)\n",
        "show(image, label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KxVx7z-ABNyq"
      },
      "source": [
        "To use this function with `Dataset.map`, the same caveats apply as with `Dataset.from_generator`: you need to describe the return shapes and types when you apply the function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cn2nIu92BMp0"
      },
      "outputs": [],
      "source": [
        "def tf_random_rotate_image(image, label):\n",
        "  im_shape = image.shape\n",
        "  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])\n",
        "  image.set_shape(im_shape)\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bWPqKbTnbDct"
      },
      "outputs": [],
      "source": [
        "rot_ds = images_ds.map(tf_random_rotate_image)\n",
        "\n",
        "for image, label in rot_ds.take(2):\n",
        "  show(image, label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ykx59-cMBwOT"
      },
      "source": [
        "### Parsing `tf.Example` protocol buffer messages\n",
        "\n",
        "Many input pipelines extract `tf.train.Example` protocol buffer messages from a\n",
        "TFRecord format. Each `tf.train.Example` record contains one or more \"features\",\n",
        "and the input pipeline typically converts these features into tensors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6wnE134b32KY"
      },
      "outputs": [],
      "source": [
        "fsns_test_file = tf.keras.utils.get_file(\"fsns.tfrec\", \"https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001\")\n",
        "dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HGypdgYOlXZz"
      },
      "source": [
        "You can work with `tf.train.Example` protos outside of a `tf.data.Dataset` to understand the data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4znsVNqnF73C"
      },
      "outputs": [],
      "source": [
        "raw_example = next(iter(dataset))\n",
        "parsed = tf.train.Example.FromString(raw_example.numpy())\n",
        "\n",
        "feature = parsed.features.feature\n",
        "raw_img = feature['image/encoded'].bytes_list.value[0]\n",
        "img = tf.image.decode_png(raw_img)\n",
        "plt.imshow(img)\n",
        "plt.axis('off')\n",
        "_ = plt.title(feature[\"image/text\"].bytes_list.value[0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cwzqp8IGC_vQ"
      },
      "outputs": [],
      "source": [
        "raw_example = next(iter(dataset))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y2X1dQNfC8Lu"
      },
      "outputs": [],
      "source": [
        "def tf_parse(eg):\n",
        "  example = tf.io.parse_example(\n",
        "      eg[tf.newaxis], {\n",
        "          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
        "          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)\n",
        "      })\n",
        "  return example['image/encoded'][0], example['image/text'][0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lGJhKDp_61A_"
      },
      "outputs": [],
      "source": [
        "img, txt = tf_parse(raw_example)\n",
        "print(txt.numpy())\n",
        "print(repr(img.numpy()[:20]), \"...\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8vFIUFzD5qIC"
      },
      "outputs": [],
      "source": [
        "decoded = dataset.map(tf_parse)\n",
        "decoded"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vRYNYkEej7Ix"
      },
      "outputs": [],
      "source": [
        "image_batch, text_batch = next(iter(decoded.batch(10)))\n",
        "image_batch.shape"
      ]
    },
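    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "An alternative to adding a batch axis with `tf.newaxis` is `tf.io.parse_single_example`, which parses one serialized record at a time. The following sketch builds a hypothetical in-memory record with the same two feature keys, so it does not depend on the downloaded file. The byte values are placeholders, not real image data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical record, built only for this illustration.\n",
        "example = tf.train.Example(features=tf.train.Features(feature={\n",
        "    'image/encoded': tf.train.Feature(\n",
        "        bytes_list=tf.train.BytesList(value=[b'placeholder-bytes'])),\n",
        "    'image/text': tf.train.Feature(\n",
        "        bytes_list=tf.train.BytesList(value=[b'hello'])),\n",
        "}))\n",
        "serialized = example.SerializeToString()\n",
        "\n",
        "feature_spec = {\n",
        "    'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
        "    'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
        "}\n",
        "\n",
        "# Each dataset element is a scalar string, so no batch axis is needed.\n",
        "example_ds = tf.data.Dataset.from_tensor_slices([serialized])\n",
        "example_ds = example_ds.map(\n",
        "    lambda eg: tf.io.parse_single_example(eg, feature_spec))\n",
        "\n",
        "parsed_eg = next(iter(example_ds))\n",
        "print(parsed_eg['image/text'].numpy())  # b'hello'"
      ]
    },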
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ry1n0UBeczit"
      },
      "source": [
        "<a id=\"time_series_windowing\"></a>\n",
        "\n",
        "### Time series windowing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t0JMgvXEz9y1"
      },
      "source": [
        "For an end-to-end time series example, see [Time series forecasting](../../tutorials/structured_data/time_series.ipynb)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hzBABBkAkkVJ"
      },
      "source": [
        "Time series data is often organized with the time axis intact.\n",
        "\n",
        "Use a simple `Dataset.range` to demonstrate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kTQgo49skjuY"
      },
      "outputs": [],
      "source": [
        "range_ds = tf.data.Dataset.range(100000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o6GLGhxgpazJ"
      },
      "source": [
        "Typically, models based on this sort of data will want a contiguous time slice. \n",
        "\n",
        "The simplest approach would be to batch the data:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ETqB7QvTCNty"
      },
      "source": [
        "#### Using `batch`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pSs9XqwQpvIN"
      },
      "outputs": [],
      "source": [
        "batches = range_ds.batch(10, drop_remainder=True)\n",
        "\n",
        "for batch in batches.take(5):\n",
        "  print(batch.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mgb2qikEtk5W"
      },
      "source": [
        "Alternatively, to make dense predictions one step into the future, shift the features and labels by one step relative to each other:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "47XfwPhetkIN"
      },
      "outputs": [],
      "source": [
        "def dense_1_step(batch):\n",
        "  # Shift features and labels one step relative to each other.\n",
        "  return batch[:-1], batch[1:]\n",
        "\n",
        "predict_dense_1_step = batches.map(dense_1_step)\n",
        "\n",
        "for features, label in predict_dense_1_step.take(3):\n",
        "  print(features.numpy(), \" => \", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjsXuINKqsS_"
      },
      "source": [
        "To predict a whole window instead of a fixed offset, you can split the batches into two parts:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FMmkQB1Gqo6x"
      },
      "outputs": [],
      "source": [
        "batches = range_ds.batch(15, drop_remainder=True)\n",
        "\n",
        "def label_next_5_steps(batch):\n",
        "  return (batch[:-5],   # Inputs: All except the last 5 steps\n",
        "          batch[-5:])   # Labels: The last 5 steps\n",
        "\n",
        "predict_5_steps = batches.map(label_next_5_steps)\n",
        "\n",
        "for features, label in predict_5_steps.take(3):\n",
        "  print(features.numpy(), \" => \", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5a611Qr3jlhl"
      },
      "source": [
        "To allow some overlap between the features of one batch and the labels of another, use `Dataset.zip`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "11dF3wyFjk2J"
      },
      "outputs": [],
      "source": [
        "feature_length = 10\n",
        "label_length = 3\n",
        "\n",
        "features = range_ds.batch(feature_length, drop_remainder=True)\n",
        "labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:label_length])\n",
        "\n",
        "predicted_steps = tf.data.Dataset.zip((features, labels))\n",
        "\n",
        "for features, label in predicted_steps.take(5):\n",
        "  print(features.numpy(), \" => \", label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "adew3o2mCURC"
      },
      "source": [
        "#### Using `window`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fF6pEdlduq8E"
      },
      "source": [
        "While using `Dataset.batch` works, there are situations where you may need finer control. The `Dataset.window` method gives you complete control, but requires some care: it returns a `Dataset` of `Datasets`. See [Dataset structure](#dataset_structure) for details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZEI2W_EBw2OX"
      },
      "outputs": [],
      "source": [
        "window_size = 5\n",
        "\n",
        "windows = range_ds.window(window_size, shift=1)\n",
        "for sub_ds in windows.take(5):\n",
        "  print(sub_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r82hWdk4x-46"
      },
      "source": [
        "The `Dataset.flat_map` method can take a dataset of datasets and flatten it into a single dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SB8AI03mnF8u"
      },
      "outputs": [],
      "source": [
        "for x in windows.flat_map(lambda x: x).take(30):\n",
        "  print(x.numpy(), end=' ')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sgLIwq9Anc34"
      },
      "source": [
        "In nearly all cases, you will want to `.batch` the dataset first:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5j_y84rmyVQa"
      },
      "outputs": [],
      "source": [
        "def sub_to_batch(sub):\n",
        "  return sub.batch(window_size, drop_remainder=True)\n",
        "\n",
        "for example in windows.flat_map(sub_to_batch).take(5):\n",
        "  print(example.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hVugrmND3Grp"
      },
      "source": [
        "Now, you can see that the `shift` argument controls how much each window moves over.\n",
        "\n",
        "Putting this together, you might write this function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LdFRv_0D4FqW"
      },
      "outputs": [],
      "source": [
        "def make_window_dataset(ds, window_size=5, shift=1, stride=1):\n",
        "  windows = ds.window(window_size, shift=shift, stride=stride)\n",
        "\n",
        "  def sub_to_batch(sub):\n",
        "    return sub.batch(window_size, drop_remainder=True)\n",
        "\n",
        "  windows = windows.flat_map(sub_to_batch)\n",
        "  return windows\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-iVxcVfEdf5b"
      },
      "outputs": [],
      "source": [
        "ds = make_window_dataset(range_ds, window_size=10, shift=5, stride=3)\n",
        "\n",
        "for example in ds.take(10):\n",
        "  print(example.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fMGMTPQ4w8pr"
      },
      "source": [
        "Then it's easy to extract labels, as before:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F0fPfZkZw6j_"
      },
      "outputs": [],
      "source": [
        "dense_labels_ds = ds.map(dense_1_step)\n",
        "\n",
        "for inputs, labels in dense_labels_ds.take(3):\n",
        "  print(inputs.numpy(), \"=>\", labels.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vyi_-ft0kvy4"
      },
      "source": [
        "### Resampling\n",
        "\n",
        "When working with a dataset that is very class-imbalanced, you may want to resample the dataset. `tf.data` provides two methods to do this. The credit card fraud dataset is a good example of this sort of problem.\n",
        "\n",
        "Note: See [Imbalanced Data](../tutorials/keras/imbalanced_data.ipynb) for a full tutorial.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U2e8dxVUlFHO"
      },
      "outputs": [],
      "source": [
        "zip_path = tf.keras.utils.get_file(\n",
        "    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',\n",
        "    fname='creditcard.zip',\n",
        "    extract=True)\n",
        "\n",
        "csv_path = zip_path.replace('.zip', '.csv')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EhkkM4Wx75S_"
      },
      "outputs": [],
      "source": [
        "creditcard_ds = tf.data.experimental.make_csv_dataset(\n",
        "    csv_path, batch_size=1024, label_name=\"Class\",\n",
        "    # Set the column types: 30 floats and an int.\n",
        "    column_defaults=[float()]*30+[int()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8O47EmHlxYX"
      },
      "source": [
        "Now, check the distribution of classes; it is highly skewed:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a8-Ss69XlzXD"
      },
      "outputs": [],
      "source": [
        "def count(counts, batch):\n",
        "  features, labels = batch\n",
        "  class_1 = labels == 1\n",
        "  class_1 = tf.cast(class_1, tf.int32)\n",
        "\n",
        "  class_0 = labels == 0\n",
        "  class_0 = tf.cast(class_0, tf.int32)\n",
        "\n",
        "  counts['class_0'] += tf.reduce_sum(class_0)\n",
        "  counts['class_1'] += tf.reduce_sum(class_1)\n",
        "\n",
        "  return counts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O1a3t_B4l_f6"
      },
      "outputs": [],
      "source": [
        "counts = creditcard_ds.take(10).reduce(\n",
        "    initial_state={'class_0': 0, 'class_1': 0},\n",
        "    reduce_func=count)\n",
        "\n",
        "counts = np.array([counts['class_0'].numpy(),\n",
        "                   counts['class_1'].numpy()]).astype(np.float32)\n",
        "\n",
        "fractions = counts/counts.sum()\n",
        "print(fractions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z1b8lFhSnDdv"
      },
      "source": [
        "A common approach to training with an imbalanced dataset is to balance it. `tf.data` includes a few methods which enable this workflow:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y8jQWsgMnjQG"
      },
      "source": [
        "#### Datasets sampling"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ov14SRrQyQE3"
      },
      "source": [
        "One approach to resampling a dataset is to use `tf.data.experimental.sample_from_datasets`. This is more applicable when you have a separate `tf.data.Dataset` for each class.\n",
        "\n",
        "Here, just use `Dataset.filter` to generate them from the credit card fraud data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6YKfCPa-nioA"
      },
      "outputs": [],
      "source": [
        "negative_ds = (\n",
        "  creditcard_ds\n",
        "    .unbatch()\n",
        "    .filter(lambda features, label: label==0)\n",
        "    .repeat())\n",
        "positive_ds = (\n",
        "  creditcard_ds\n",
        "    .unbatch()\n",
        "    .filter(lambda features, label: label==1)\n",
        "    .repeat())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8FNd3sQjzl9-"
      },
      "outputs": [],
      "source": [
        "for features, label in positive_ds.batch(10).take(1):\n",
        "  print(label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxLAr-7p0ATX"
      },
      "source": [
        "To use `tf.data.experimental.sample_from_datasets`, pass the datasets and the weight for each:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vjdPVIFCngOb"
      },
      "outputs": [],
      "source": [
        "balanced_ds = tf.data.experimental.sample_from_datasets(\n",
        "    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2K4ObOms082B"
      },
      "source": [
        "Now the dataset produces examples of each class with 50/50 probability:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Myvkw21Rz-fH"
      },
      "outputs": [],
      "source": [
        "for features, labels in balanced_ds.take(10):\n",
        "  print(labels.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OUTE3eb9nckY"
      },
      "source": [
        "#### Rejection resampling"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kZ9ezkK6irMD"
      },
      "source": [
        "One problem with the above `tf.data.experimental.sample_from_datasets` approach is that\n",
        "it needs a separate `tf.data.Dataset` per class. Using `Dataset.filter`\n",
        "works, but results in all the data being loaded twice.\n",
        "\n",
        "The `tf.data.experimental.rejection_resample` function can be applied to a dataset to rebalance it, while only loading it once. Elements will be dropped from the dataset to achieve balance.\n",
        "\n",
        "`tf.data.experimental.rejection_resample` takes a `class_func` argument. This `class_func` is applied to each dataset element, and is used to determine which class an example belongs to for the purposes of balancing.\n",
        "\n",
        "The elements of `creditcard_ds` are already `(features, label)` pairs, so the `class_func` just needs to return those labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zC_Cuzw8lhI5"
      },
      "outputs": [],
      "source": [
        "def class_func(features, label):\n",
        "  return label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DdKmE8Jumlp0"
      },
      "source": [
        "The resampler also needs a target distribution, and optionally an initial distribution estimate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9tv0tWNxmkzM"
      },
      "outputs": [],
      "source": [
        "resampler = tf.data.experimental.rejection_resample(\n",
        "    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YxJrOZVToGuE"
      },
      "source": [
        "The resampler deals with individual examples, so you must `unbatch` the dataset before applying the resampler:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fY6VIhr3oGHG"
      },
      "outputs": [],
      "source": [
        "resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L-HnC1s8idqV"
      },
      "source": [
        "The resampler creates `(class, example)` pairs from the output of the `class_func`. In this case, the `example` was already a `(feature, label)` pair, so use `map` to drop the extra copy of the labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KpfCGU6BiaZq"
      },
      "outputs": [],
      "source": [
        "balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j3d2jyEhx9kD"
      },
      "source": [
        "Now the dataset produces examples of each class with 50/50 probability:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XGLYChBQwkDV"
      },
      "outputs": [],
      "source": [
        "for features, labels in balanced_ds.take(10):\n",
        "  print(labels.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vYFKQx3bUBeU"
      },
      "source": [
        "## Iterator Checkpointing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SOGg1UFhUE4z"
      },
      "source": [
        "TensorFlow supports [taking checkpoints](https://www.tensorflow.org/guide/checkpoint) so that when your training process restarts it can restore the latest checkpoint to recover most of its progress. In addition to checkpointing the model variables, you can also checkpoint the progress of the dataset iterator. This could be useful if you have a large dataset and don't want to start the dataset from the beginning on each restart. Note, however, that iterator checkpoints may be large, since transformations such as `shuffle` and `prefetch` require buffering elements within the iterator.\n",
        "\n",
        "To include your iterator in a checkpoint, pass the iterator to the `tf.train.Checkpoint` constructor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Fsm9wvKUsNC"
      },
      "outputs": [],
      "source": [
        "range_ds = tf.data.Dataset.range(20)\n",
        "\n",
        "iterator = iter(range_ds)\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)\n",
        "manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)\n",
        "\n",
        "print([next(iterator).numpy() for _ in range(5)])\n",
        "\n",
        "save_path = manager.save()\n",
        "\n",
        "print([next(iterator).numpy() for _ in range(5)])\n",
        "\n",
        "ckpt.restore(manager.latest_checkpoint)\n",
        "\n",
        "print([next(iterator).numpy() for _ in range(5)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gxWglTwX9Fex"
      },
      "source": [
        "Note: It is not possible to checkpoint an iterator which relies on external state such as a `tf.py_function`. Attempting to do so will raise an exception complaining about the external state."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uLRdedPpbDdD"
      },
      "source": [
        "## Using tf.data with tf.keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JTQe8daMcgFz"
      },
      "source": [
        "The `tf.keras` API simplifies many aspects of creating and executing machine\n",
        "learning models. Its `.fit()`, `.evaluate()`, and `.predict()` APIs support datasets as inputs. Here is a quick dataset and model setup:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-bfjqm0hOfES"
      },
      "outputs": [],
      "source": [
        "train, test = tf.keras.datasets.fashion_mnist.load_data()\n",
        "\n",
        "images, labels = train\n",
        "images = images/255.0\n",
        "labels = labels.astype(np.int32)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wDhF3rGnbDdD"
      },
      "outputs": [],
      "source": [
        "fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))\n",
        "fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "  tf.keras.layers.Flatten(),\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "model.compile(optimizer='adam',\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), \n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rdogg8CfHs-G"
      },
      "source": [
        "Passing a dataset of `(feature, label)` pairs is all that's needed for `Model.fit` and `Model.evaluate`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cu4kPzOHnlt"
      },
      "outputs": [],
      "source": [
        "model.fit(fmnist_train_ds, epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FzpAQfJMJF41"
      },
      "source": [
        "If you pass an infinite dataset, for example by calling `Dataset.repeat()`, you must also pass the `steps_per_epoch` argument:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bp1BpzlyJinb"
      },
      "outputs": [],
      "source": [
        "model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iTLsw_nqJpTw"
      },
      "source": [
        "For evaluation, pass the dataset to `Model.evaluate`. With no `steps` argument, it runs over the whole dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TnlRHlaL-XUI"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(fmnist_train_ds)\n",
        "print(\"Loss :\", loss)\n",
        "print(\"Accuracy :\", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C8UBU3CJKEA4"
      },
      "source": [
        "For long datasets, set the number of steps to evaluate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uVgamf9HKDon"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)\n",
        "print(\"Loss :\", loss)\n",
        "print(\"Accuracy :\", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aZYhJ_YSIl6w"
      },
      "source": [
        "The labels are not required when calling `Model.predict`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "343lXJ-pIqWD"
      },
      "outputs": [],
      "source": [
        "predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)\n",
        "result = model.predict(predict_ds, steps=10)\n",
        "print(result.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YfzZORwLI202"
      },
      "source": [
        "But the labels are ignored if you do pass a dataset containing them:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mgQJTPrT-2WF"
      },
      "outputs": [],
      "source": [
        "result = model.predict(fmnist_train_ds, steps=10)\n",
        "print(result.shape)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "data.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
